{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "machine_shape": "hm",
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "gpuClass": "standard",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "7716ce48c050461ab4c41b8cbc8c3720": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HBoxModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HBoxView",
            "box_style": "",
            "children": [
              "IPY_MODEL_b9cf72c6e7ca454cbb07a0087be40bde",
              "IPY_MODEL_39892f4773e74558b1a38487c11a3020",
              "IPY_MODEL_2efe9bc4534a41fb8e8b86f2270b719c"
            ],
            "layout": "IPY_MODEL_fc913b2d053749b989b0a5024f99f2ea"
          }
        },
        "b9cf72c6e7ca454cbb07a0087be40bde": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_17c413725e5846e386946103f87d6a74",
            "placeholder": "​",
            "style": "IPY_MODEL_6b8bd2625c7940ea900b153f87e85109",
            "value": "Dl Completed...: 100%"
          }
        },
        "39892f4773e74558b1a38487c11a3020": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "FloatProgressModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "ProgressView",
            "bar_style": "success",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_cddba9f47ce140d29569bf2b8afa3f83",
            "max": 5,
            "min": 0,
            "orientation": "horizontal",
            "style": "IPY_MODEL_cee0e6c0e4794f4fb9e1ddedacf7ce00",
            "value": 5
          }
        },
        "2efe9bc4534a41fb8e8b86f2270b719c": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "model_module_version": "1.5.0",
          "state": {
            "_dom_classes": [],
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "HTMLModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/controls",
            "_view_module_version": "1.5.0",
            "_view_name": "HTMLView",
            "description": "",
            "description_tooltip": null,
            "layout": "IPY_MODEL_43a7d99f797442d58770bba064b4de07",
            "placeholder": "​",
            "style": "IPY_MODEL_79e4831b59fb4dda9caa95dd5de85500",
            "value": " 5/5 [00:00&lt;00:00, 15.75 file/s]"
          }
        },
        "fc913b2d053749b989b0a5024f99f2ea": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "17c413725e5846e386946103f87d6a74": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "6b8bd2625c7940ea900b153f87e85109": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        },
        "cddba9f47ce140d29569bf2b8afa3f83": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "cee0e6c0e4794f4fb9e1ddedacf7ce00": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "ProgressStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "bar_color": null,
            "description_width": ""
          }
        },
        "43a7d99f797442d58770bba064b4de07": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "model_module_version": "1.2.0",
          "state": {
            "_model_module": "@jupyter-widgets/base",
            "_model_module_version": "1.2.0",
            "_model_name": "LayoutModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "LayoutView",
            "align_content": null,
            "align_items": null,
            "align_self": null,
            "border": null,
            "bottom": null,
            "display": null,
            "flex": null,
            "flex_flow": null,
            "grid_area": null,
            "grid_auto_columns": null,
            "grid_auto_flow": null,
            "grid_auto_rows": null,
            "grid_column": null,
            "grid_gap": null,
            "grid_row": null,
            "grid_template_areas": null,
            "grid_template_columns": null,
            "grid_template_rows": null,
            "height": null,
            "justify_content": null,
            "justify_items": null,
            "left": null,
            "margin": null,
            "max_height": null,
            "max_width": null,
            "min_height": null,
            "min_width": null,
            "object_fit": null,
            "object_position": null,
            "order": null,
            "overflow": null,
            "overflow_x": null,
            "overflow_y": null,
            "padding": null,
            "right": null,
            "top": null,
            "visibility": null,
            "width": null
          }
        },
        "79e4831b59fb4dda9caa95dd5de85500": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "model_module_version": "1.5.0",
          "state": {
            "_model_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_model_name": "DescriptionStyleModel",
            "_view_count": null,
            "_view_module": "@jupyter-widgets/base",
            "_view_module_version": "1.2.0",
            "_view_name": "StyleView",
            "description_width": ""
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "## Generating random initializations for a neural network"
      ],
      "metadata": {
        "id": "i_PYbvJuEbbH"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "LAYER_SIZES = [200*200*3, 2048, 1024, 2]\n",
        "PARAM_SCALE = 0.01"
      ],
      "metadata": {
        "id": "adQ9pbzJ94Hk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax import random"
      ],
      "metadata": {
        "id": "k9xOkoxB-cK0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def random_layer_params(m, n, key, scale=1e-2):\n",
        "  \"\"\"Sample the weights and biases of one dense layer.\n",
        "\n",
        "  `key` is split into two subkeys: one draws the (n, m) weight matrix and the\n",
        "  other the (n,) bias vector. Both are standard-normal samples multiplied by\n",
        "  `scale`. Returns the pair (weights, biases).\n",
        "  \"\"\"\n",
        "  weight_key, bias_key = random.split(key)\n",
        "  weights = scale * random.normal(weight_key, (n, m))\n",
        "  biases = scale * random.normal(bias_key, (n,))\n",
        "  return weights, biases\n",
        "\n",
        "def init_network_params(sizes, key=random.PRNGKey(0), scale=0.01):\n",
        "  \"\"\"Create (weights, biases) pairs for every layer of a fully-connected network.\n",
        "\n",
        "  `sizes` lists the layer widths from input to output; each consecutive pair of\n",
        "  widths gets its own PRNG subkey and produces one dense layer.\n",
        "  \"\"\"\n",
        "  layer_keys = random.split(key, len(sizes)-1)\n",
        "  layers = []\n",
        "  for fan_in, fan_out, layer_key in zip(sizes[:-1], sizes[1:], layer_keys):\n",
        "    layers.append(random_layer_params(fan_in, fan_out, layer_key, scale))\n",
        "  return layers"
      ],
      "metadata": {
        "id": "mFvGPufs-GPM",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "738e3c3e-5629-4117-d09c-0401428c73d7"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "key = random.PRNGKey(42)"
      ],
      "metadata": {
        "id": "BsHFmOWxA8Lc"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)"
      ],
      "metadata": {
        "id": "yW0rZY1TFPjV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for i,layer in enumerate(params):\n",
        "  w,b = layer\n",
        "  print(i, w.shape, b.shape)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6j9B80QTlzZH",
        "outputId": "c9e4e981-9104-4b09-bb19-95e0110d11a4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0 (2048, 120000) (2048,)\n",
            "1 (1024, 2048) (1024,)\n",
            "2 (2, 1024) (2,)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "shapes = jax.tree_util.tree_map(lambda p: p.shape, params)"
      ],
      "metadata": {
        "id": "C95WsgsPdxze"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "for i,shape in enumerate(shapes):\n",
        "  print(i, shape)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6eu9R09kpoSN",
        "outputId": "fc6a9870-b0f1-46f6-bb1b-65a0bc133e64"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "0 ((2048, 120000), (2048,))\n",
            "1 ((1024, 2048), (1024,))\n",
            "2 ((2, 1024), (2,))\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(params)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "S5hxbvjQls5r",
        "outputId": "d10e7b4c-e127-47ed-d285-a50e10447916"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Array([[-0.01233545,  0.00821559, -0.00978333, ..., -0.00762261,\n",
              "          0.01153923, -0.00699568],\n",
              "        [ 0.00595356, -0.00548696,  0.00862382, ..., -0.00660049,\n",
              "         -0.01387328,  0.00377337],\n",
              "        [-0.00239754, -0.01310708,  0.01655764, ...,  0.00812733,\n",
              "         -0.0122619 ,  0.0073874 ],\n",
              "        ...,\n",
              "        [ 0.01216634, -0.01617356, -0.0034067 , ...,  0.00477375,\n",
              "         -0.00057253,  0.00784415],\n",
              "        [-0.01213108, -0.00440847, -0.02979285, ...,  0.00520762,\n",
              "         -0.0157708 ,  0.00563094],\n",
              "        [-0.00177029, -0.00257568,  0.01720736, ...,  0.00065184,\n",
              "         -0.00535367, -0.00308625]], dtype=float32),\n",
              " Array([ 0.00885022,  0.0077482 , -0.00685802, ..., -0.0108596 ,\n",
              "        -0.02526451,  0.00504387], dtype=float32),\n",
              " Array([[ 0.00833544, -0.00892102,  0.01756026, ...,  0.00956038,\n",
              "          0.02225046, -0.00698299],\n",
              "        [-0.00055303, -0.00299709,  0.0233378 , ..., -0.00282513,\n",
              "          0.00495245, -0.00278503],\n",
              "        [-0.0128897 , -0.01443312,  0.00437023, ...,  0.0099717 ,\n",
              "          0.01219303,  0.00933401],\n",
              "        ...,\n",
              "        [-0.01043672,  0.00142216, -0.00108644, ...,  0.00589371,\n",
              "          0.00438441, -0.00884235],\n",
              "        [-0.00751525,  0.00096751, -0.00141621, ...,  0.01048672,\n",
              "         -0.00359207, -0.00167823],\n",
              "        [-0.00822164,  0.02248102,  0.00858274, ..., -0.00129686,\n",
              "          0.01258506,  0.00516878]], dtype=float32),\n",
              " Array([-0.00153434, -0.00426756, -0.0082617 , ...,  0.00311023,\n",
              "        -0.01695804, -0.00391124], dtype=float32),\n",
              " Array([[ 0.00113224, -0.00548538,  0.00278324, ..., -0.01151069,\n",
              "         -0.01429665, -0.00419622],\n",
              "        [ 0.00064889, -0.00967699, -0.00876216, ..., -0.00548893,\n",
              "          0.0090474 , -0.00056804]], dtype=float32),\n",
              " Array([ 0.01284917, -0.01317841], dtype=float32)]"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Leaves and nodes"
      ],
      "metadata": {
        "id": "5W-BPZzeGPrx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import numpy as np\n",
        "import jax.numpy as jnp\n",
        "import collections"
      ],
      "metadata": {
        "id": "MDxnkU136Jzz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "Point = collections.namedtuple('Point', ['x', 'y'])"
      ],
      "metadata": {
        "id": "GNkF1Go37Tjx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "example_pytree = [\n",
        "    {\n",
        "        'a': [1, 2, 3],\n",
        "        'b': jnp.array([1, 2, 3]),\n",
        "        'c': np.array([1, 2, 3])\n",
        "    },\n",
        "    [42, [44, 46], None],\n",
        "    31337,\n",
        "    (50, (60, 70)),\n",
        "    Point(640, 480),\n",
        "    collections.OrderedDict([('a', 100), ('b', 200)]),\n",
        "    'some string'\n",
        "]"
      ],
      "metadata": {
        "id": "r81ppWR_23Kq"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(example_pytree)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "M8BVdFIX6IKp",
        "outputId": "d401bcc1-61bb-4cc9-e852-df5361e2b54a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[1,\n",
              " 2,\n",
              " 3,\n",
              " Array([1, 2, 3], dtype=int32),\n",
              " array([1, 2, 3]),\n",
              " 42,\n",
              " 44,\n",
              " 46,\n",
              " 31337,\n",
              " 50,\n",
              " 60,\n",
              " 70,\n",
              " 640,\n",
              " 480,\n",
              " 100,\n",
              " 200,\n",
              " 'some string']"
            ]
          },
          "metadata": {},
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Back to the MLP example from Chapter 2"
      ],
      "metadata": {
        "id": "1Ax3_LM2Gdvx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import tensorflow as tf\n",
        "# Uncomment the next line to hide GPUs from TF so that it does not grab all GPU memory.\n",
        "#tf.config.set_visible_devices([], device_type='GPU')\n",
        "\n",
        "import tensorflow_datasets as tfds\n",
        "\n",
        "data_dir = '/tmp/tfds'\n",
        "\n",
        "# as_supervised=True gives us the (image, label) as a tuple instead of a dict\n",
        "data, info = tfds.load(name=\"mnist\",\n",
        "                       data_dir=data_dir,\n",
        "                       as_supervised=True,\n",
        "                       with_info=True)\n",
        "\n",
        "data_train = data['train']\n",
        "data_test  = data['test']"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 84,
          "referenced_widgets": [
            "7716ce48c050461ab4c41b8cbc8c3720",
            "b9cf72c6e7ca454cbb07a0087be40bde",
            "39892f4773e74558b1a38487c11a3020",
            "2efe9bc4534a41fb8e8b86f2270b719c",
            "fc913b2d053749b989b0a5024f99f2ea",
            "17c413725e5846e386946103f87d6a74",
            "6b8bd2625c7940ea900b153f87e85109",
            "cddba9f47ce140d29569bf2b8afa3f83",
            "cee0e6c0e4794f4fb9e1ddedacf7ce00",
            "43a7d99f797442d58770bba064b4de07",
            "79e4831b59fb4dda9caa95dd5de85500"
          ]
        },
        "id": "cQKibdtJGV_K",
        "outputId": "17325aaf-e8a9-4080-b4b1-294e7da23545"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Downloading and preparing dataset Unknown size (download: Unknown size, generated: Unknown size, total: Unknown size) to /tmp/tfds/mnist/3.0.1...\n"
          ]
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Dl Completed...:   0%|          | 0/5 [00:00<?, ? file/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "7716ce48c050461ab4c41b8cbc8c3720"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Dataset mnist downloaded and prepared to /tmp/tfds/mnist/3.0.1. Subsequent calls will reuse this data.\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "HEIGHT = 28\n",
        "WIDTH  = 28\n",
        "CHANNELS = 1\n",
        "NUM_PIXELS = HEIGHT * WIDTH * CHANNELS\n",
        "NUM_LABELS = info.features['label'].num_classes"
      ],
      "metadata": {
        "id": "OIJj7-JAGm9R"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def preprocess(img, label):\n",
        "  \"\"\"Cast the image to float32 and divide it by 255; the label is returned unchanged.\"\"\"\n",
        "  scaled_img = tf.cast(img, tf.float32) / 255.0\n",
        "  return scaled_img, label\n",
        "\n",
        "train_data = tfds.as_numpy(data_train.map(preprocess).batch(32).prefetch(1))\n",
        "test_data  = tfds.as_numpy(data_test.map(preprocess).batch(32).prefetch(1))"
      ],
      "metadata": {
        "id": "QHptnydPGrMp"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "LAYER_SIZES = [28*28, 512, 10]\n",
        "PARAM_SCALE = 0.01"
      ],
      "metadata": {
        "id": "Mk9TirwrGu_Q"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import jax\n",
        "import jax.numpy as jnp\n",
        "from jax import grad, jit, vmap, value_and_grad\n",
        "from jax import random\n",
        "from jax.nn import swish, logsumexp, one_hot"
      ],
      "metadata": {
        "id": "5iNLaGrjGwBy"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def init_network_params(sizes, key=random.PRNGKey(0), scale=1e-2):\n",
        "  \"\"\"Initialize all layers for a fully-connected neural network with given sizes.\n",
        "\n",
        "  Args:\n",
        "    sizes: layer widths from input to output, e.g. [784, 512, 10].\n",
        "    key: PRNG key; one subkey is derived from it for each dense layer.\n",
        "    scale: factor applied to the standard-normal samples.\n",
        "\n",
        "  Returns:\n",
        "    A list with one (weights, biases) tuple per layer, where weights has shape\n",
        "    (n_out, n_in) and biases has shape (n_out,).\n",
        "  \"\"\"\n",
        "\n",
        "  def random_layer_params(m, n, key, scale=1e-2):\n",
        "    \"\"\"A helper function to randomly initialize weights and biases of a dense layer\"\"\"\n",
        "    w_key, b_key = random.split(key)\n",
        "    return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n",
        "\n",
        "  # len(sizes) widths describe len(sizes)-1 layers, so request exactly that many subkeys.\n",
        "  keys = random.split(key, len(sizes)-1)\n",
        "  return [random_layer_params(m, n, k, scale) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]"
      ],
      "metadata": {
        "id": "bdoPWkiNG2Sz"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "init_params = init_network_params(LAYER_SIZES, random.PRNGKey(0), scale=PARAM_SCALE)"
      ],
      "metadata": {
        "id": "7vuYS9gGG4Lh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def predict(params, image):\n",
        "  \"\"\"Compute the logits for a single example.\n",
        "\n",
        "  Every layer except the last applies an affine map followed by swish; the last\n",
        "  layer applies only the affine map.\n",
        "  \"\"\"\n",
        "  hidden_layers, output_layer = params[:-1], params[-1]\n",
        "\n",
        "  signal = image\n",
        "  for weight, bias in hidden_layers:\n",
        "    signal = swish(jnp.dot(weight, signal) + bias)\n",
        "\n",
        "  out_weight, out_bias = output_layer\n",
        "  return jnp.dot(out_weight, signal) + out_bias"
      ],
      "metadata": {
        "id": "nMMMtOUsG5X6"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "batched_predict = vmap(predict, in_axes=(None, 0))"
      ],
      "metadata": {
        "id": "G-Yt4VDfG6g0"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "INIT_LR = 1.0\n",
        "DECAY_RATE = 0.95\n",
        "DECAY_STEPS = 5"
      ],
      "metadata": {
        "id": "XfrpAxSxG-VA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def loss(params, images, targets):\n",
        "  \"\"\"Categorical cross entropy loss function.\n",
        "\n",
        "  Args:\n",
        "    params: list of (weights, biases) tuples, as consumed by `batched_predict`.\n",
        "    images: batch of flattened images, batched along the leading axis.\n",
        "    targets: one-hot labels with the same shape as the logits.\n",
        "\n",
        "  Returns:\n",
        "    Scalar: the negated mean of targets * log-probabilities over all entries.\n",
        "  \"\"\"\n",
        "  logits = batched_predict(params, images)\n",
        "  # Normalize each example separately: reduce over the class axis only and keep\n",
        "  # that axis so the result broadcasts against logits. Without `axis`, logsumexp\n",
        "  # reduces over the whole batch and the rows are not valid log-probabilities.\n",
        "  log_preds = logits - logsumexp(logits, axis=-1, keepdims=True)\n",
        "  return -jnp.mean(targets*log_preds)\n",
        "\n",
        "@jax.jit\n",
        "def update(params, x, y, epoch_number):\n",
        "  \"\"\"Run one gradient-descent step and return (new_params, loss_value).\n",
        "\n",
        "  The step size is INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS).\n",
        "  The two print calls report the pytree of shapes for the parameters and gradients.\n",
        "  \"\"\"\n",
        "  print(f\"Params shapes: {jax.tree_util.tree_map(lambda p: p.shape, params)}\")\n",
        "  loss_value, grads = value_and_grad(loss)(params, x, y)\n",
        "  print(f\"Grads shapes: {jax.tree_util.tree_map(lambda p: p.shape, grads)}\")\n",
        "\n",
        "  lr = INIT_LR * DECAY_RATE ** (epoch_number / DECAY_STEPS)\n",
        "  new_params = []\n",
        "  for (w, b), (dw, db) in zip(params, grads):\n",
        "    new_params.append((w - lr * dw, b - lr * db))\n",
        "  return new_params, loss_value"
      ],
      "metadata": {
        "id": "9M7ra3JNHAEV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "x, y = next(iter(train_data))"
      ],
      "metadata": {
        "id": "dtCJ1e5ZHB2C"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "x = jnp.reshape(x, (len(x), NUM_PIXELS))\n",
        "y = one_hot(y, NUM_LABELS)"
      ],
      "metadata": {
        "id": "GxlfpSgoHsgw"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "params, loss_value = update(init_params, x, y, 0)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "K-QcJVIWH75l",
        "outputId": "cb0c2669-a440-4105-c665-66cdad10c7e6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "Params shapes: [((512, 784), (512,)), ((10, 512), (10,))]\n",
            "Grads shapes: [((512, 784), (512,)), ((10, 512), (10,))]\n"
          ]
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Working with pytrees"
      ],
      "metadata": {
        "id": "BJNad-_cPNg9"
      }
    },
    {
      "cell_type": "markdown",
      "source": [
        "Flatten/unflatten"
      ],
      "metadata": {
        "id": "wW320CFHp3SE"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "params = init_network_params(LAYER_SIZES, key, scale=PARAM_SCALE)"
      ],
      "metadata": {
        "id": "dRMsOme2Pbi5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "scaled_params = jax.tree_util.tree_map(lambda p: 10*p, params)"
      ],
      "metadata": {
        "id": "BNrzasLWITED"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "some_pytree = [\n",
        "    [1,1,1],\n",
        "    [\n",
        "        [10,10,10], [20, 20]\n",
        "    ]\n",
        "]"
      ],
      "metadata": {
        "id": "AJPxeNpbPiPQ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_map(lambda p: p+1, some_pytree)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9vqnMrLuWNWC",
        "outputId": "1dc0444d-5806-415a-b2df-326e868afbc3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[[2, 2, 2], [[11, 11, 11], [21, 21]]]"
            ]
          },
          "metadata": {},
          "execution_count": 23
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "leaves, struct = jax.tree_util.tree_flatten(some_pytree)"
      ],
      "metadata": {
        "id": "tcwv-6fkWSrr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "leaves"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "s5r7oTlVWrYD",
        "outputId": "2d173f35-1da8-4a98-8b1d-010b5aa18f36"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[1, 1, 1, 10, 10, 10, 20, 20]"
            ]
          },
          "metadata": {},
          "execution_count": 25
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "struct"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BNFDeoxcWsZL",
        "outputId": "cbd09bd8-a27a-454b-89eb-097606c7a323"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "PyTreeDef([[*, *, *], [[*, *, *], [*, *]]])"
            ]
          },
          "metadata": {},
          "execution_count": 26
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Build a concrete list rather than a lazy `map` iterator: an iterator is\n",
        "# exhausted after one pass, so re-running the unflatten cell would fail.\n",
        "updated_leaves = [leaf + 1 for leaf in leaves]"
      ],
      "metadata": {
        "id": "fkw__kDOWtUC"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_unflatten(struct, updated_leaves)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-2VRfskbWyD_",
        "outputId": "1aa809a4-4ce5-4a54-beff-2312f7339406"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[[2, 2, 2], [[11, 11, 11], [21, 21]]]"
            ]
          },
          "metadata": {},
          "execution_count": 30
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
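        "Flatten/unflatten using a 1D array: `ravel_pytree` packs all leaves into a single 1D array and returns a function that restores the original structure."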
      ],
      "metadata": {
        "id": "UUHWkk9Sp6hd"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "from jax.flatten_util import ravel_pytree"
      ],
      "metadata": {
        "id": "n_1f-B94W5h4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "leaves, unflatten_func = ravel_pytree(some_pytree)"
      ],
      "metadata": {
        "id": "3j66435SZUMe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "leaves"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sv_xBmdUZX0-",
        "outputId": "5af64bce-8dce-4eeb-d59a-c5b877750820"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Array([ 1,  1,  1, 10, 10, 10, 20, 20], dtype=int32)"
            ]
          },
          "metadata": {},
          "execution_count": 38
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "unflatten_func"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "eI6X3AEfZYr1",
        "outputId": "a5b9300a-1ba7-431e-f35b-34ba6de169bf"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<function jax._src.flatten_util.ravel_pytree.<locals>.<lambda>(flat)>"
            ]
          },
          "metadata": {},
          "execution_count": 36
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "unflatten_func(leaves)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "VErdU6M5ZbgT",
        "outputId": "f567d61f-dfe0-4ceb-f4d5-880ae25c7a0d"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[[Array(1, dtype=int32), Array(1, dtype=int32), Array(1, dtype=int32)],\n",
              " [[Array(10, dtype=int32), Array(10, dtype=int32), Array(10, dtype=int32)],\n",
              "  [Array(20, dtype=int32), Array(20, dtype=int32)]]]"
            ]
          },
          "metadata": {},
          "execution_count": 37
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Reducing a tree: fold all leaves of a pytree into a single value with `tree_reduce`."
      ],
      "metadata": {
        "id": "P_Rq-BDCp-IA"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_reduce(lambda acc,value: acc+value, some_pytree, initializer=0)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZdKqTRCLZls3",
        "outputId": "44b60578-5523-428d-a307-cd350e3f20c4"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "73"
            ]
          },
          "metadata": {},
          "execution_count": 43
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Transposing a pytree"
      ],
      "metadata": {
        "id": "Yxt2szBQwLUN"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "import math\n",
        "from collections import namedtuple\n",
        "Point = namedtuple('Point', ['x', 'y'])"
      ],
      "metadata": {
        "id": "0-axCiAYqM6A"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "points = [\n",
        "    Point(0.0, 0.0),\n",
        "    Point(3.0, 0.0),\n",
        "    Point(0.0, 4.0)\n",
        "]"
      ],
      "metadata": {
        "id": "MDtJHaCqwMzJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def rotate_point(p, theta):\n",
        "  \"\"\"Rotate the point `p` about the origin by the angle `theta`.\n",
        "\n",
        "  Args:\n",
        "    p: object exposing `x` and `y` attributes (e.g. a `Point`).\n",
        "    theta: rotation angle, passed to `math.cos` / `math.sin` (radians).\n",
        "\n",
        "  Returns:\n",
        "    A new `Point` holding the rotated coordinates.\n",
        "  \"\"\"\n",
        "  # Evaluate the trig terms once; both coordinates reuse them.\n",
        "  cos_t, sin_t = math.cos(theta), math.sin(theta)\n",
        "  rotated_x = p.x * cos_t - p.y * sin_t\n",
        "  rotated_y = p.x * sin_t + p.y * cos_t\n",
        "  return Point(rotated_x, rotated_y)"
      ],
      "metadata": {
        "id": "nNZA1RKQwg69"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "rotate_point(points[1], math.pi)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "lwNaiUegx0sy",
        "outputId": "a00fc3d7-ffd2-4d01-95a3-670973c35c33"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Point(x=-3.0, y=3.6739403974420594e-16)"
            ]
          },
          "metadata": {},
          "execution_count": 50
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Expected failure: `points` is a list of Points whose leaves are Python\n",
        "# scalars (shape ()), so there is no axis 0 for vmap to map over.\n",
        "# The error is caught and printed so that \"Run all\" continues past this cell.\n",
        "try:\n",
        "  jax.vmap(rotate_point, in_axes=(0, None))(points, math.pi)\n",
        "except ValueError as err:\n",
        "  print(f\"{type(err).__name__}: {err}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Cz4Xa0s3x4ta",
        "outputId": "23ae6813-627a-453a-cb8f-d3701573e110"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Expected failure: `jnp.array(points)` turns the Points into a plain array,\n",
        "# so the value that reaches `rotate_point` no longer has `.x` / `.y` fields.\n",
        "# The error is caught and printed so that \"Run all\" continues past this cell.\n",
        "try:\n",
        "  jax.vmap(rotate_point, in_axes=(0, None))(jnp.array(points), math.pi)\n",
        "except AttributeError as err:\n",
        "  print(f\"{type(err).__name__}: {err}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "N71ehNo8-Oxb",
        "outputId": "c1875d9c-0375-4a29-f7fc-0e8330332b60"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "AttributeError: 'ShapedArray' object has no attribute 'x'\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "points"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "OY1tkV_2yEgz",
        "outputId": "626ee0b9-33ae-40da-d035-58d9b95a8216"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Point(x=0.0, y=0.0), Point(x=3.0, y=0.0), Point(x=0.0, y=4.0)]"
            ]
          },
          "metadata": {},
          "execution_count": 55
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_structure(points)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XobK8BYaytYQ",
        "outputId": "19cf9153-70db-4ba5-fcb5-5dc72f6dfee1"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "PyTreeDef([CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *])])"
            ]
          },
          "metadata": {},
          "execution_count": 58
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_structure(points[0])"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MtvXMYhuywQq",
        "outputId": "99398059-bf8f-48a1-f7bc-340de6eaf3a2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "PyTreeDef(CustomNode(namedtuple[Point], [*, *]))"
            ]
          },
          "metadata": {},
          "execution_count": 59
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Expected failure: `tree_structure(points)` already contains the Point nodes,\n",
        "# so composing it with the inner Point treedef describes a deeper tree than\n",
        "# `points` actually is. The outer treedef must be a plain list of leaves.\n",
        "# The error is caught and printed so that \"Run all\" continues past this cell.\n",
        "try:\n",
        "  jax.tree_util.tree_transpose(\n",
        "    outer_treedef = jax.tree_util.tree_structure(points),\n",
        "    inner_treedef = jax.tree_util.tree_structure(points[0]),\n",
        "    pytree_to_transpose=points\n",
        "  )\n",
        "except TypeError as err:\n",
        "  print(f\"{type(err).__name__}: {err}\")"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "6NEq1SJnye9s",
        "outputId": "699af899-fbd9-45d4-a919-78bc4459106c"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "TypeError: Mismatch\n",
            "PyTreeDef([CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *])])\n",
            " != \n",
            "PyTreeDef([CustomNode(namedtuple[Point], [CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *])]), CustomNode(namedtuple[Point], [CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *])]), CustomNode(namedtuple[Point], [CustomNode(namedtuple[Point], [*, *]), CustomNode(namedtuple[Point], [*, *])])])\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# One placeholder leaf per point: a list-only structure without the Point nodes.\n",
        "jax.tree_util.tree_structure([0] * len(points))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "jodfn_rO0am8",
        "outputId": "54fcf1c1-b2ae-44f6-88c6-cfa1cfcc3c2e"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "PyTreeDef([*, *, *])"
            ]
          },
          "metadata": {},
          "execution_count": 66
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Outer structure: a flat list with one placeholder leaf per point.\n",
        "# Inner structure: a single Point. The result is one Point whose fields are lists.\n",
        "points_t = jax.tree_util.tree_transpose(\n",
        "  jax.tree_util.tree_structure([0] * len(points)),\n",
        "  jax.tree_util.tree_structure(points[0]),\n",
        "  points,\n",
        ")"
      ],
      "metadata": {
        "id": "X4dNI2UZy2vh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "points_t"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "AflrUcRH0qua",
        "outputId": "997f63a1-34f3-40f9-c466-052ba9d6c120"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Point(x=[0.0, 3.0, 0.0], y=[0.0, 0.0, 4.0])"
            ]
          },
          "metadata": {},
          "execution_count": 68
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Expected to fail: the leaves of points_t are plain Python floats (rank 0),\n",
        "# so vmap has no axis 0 to map over.\n",
        "jax.vmap(rotate_point, in_axes=(0, None))(points_t, math.pi)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 537
        },
        "id": "AtY2AW9a0QIl",
        "outputId": "c2103e2f-a4ff-41bf-90ac-c0f13f8515c3"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "error",
          "ename": "ValueError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mIndexError\u001b[0m                                Traceback (most recent call last)",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_get_axis_size\u001b[0;34m(name, shape, axis)\u001b[0m\n\u001b[1;32m   1792\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1793\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1794\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mIndexError\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mIndexError\u001b[0m: tuple index out of range",
            "\nThe above exception was the direct cause of the following exception:\n",
            "\u001b[0;31mUnfilteredStackTrace\u001b[0m                      Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-69-06fa7acb6de1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrotate_point\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpoints_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/traceback_util.py\u001b[0m in \u001b[0;36mreraise_with_filtered_traceback\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m    165\u001b[0m     \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 166\u001b[0;31m       \u001b[0;32mreturn\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    167\u001b[0m     \u001b[0;32mexcept\u001b[0m \u001b[0mException\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0me\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36mvmap_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m   1771\u001b[0m     axis_size_ = (axis_size if axis_size is not None else\n\u001b[0;32m-> 1772\u001b[0;31m                   _mapped_axis_size(fun, in_tree, args_flat, in_axes_flat, \"vmap\"))\n\u001b[0m\u001b[1;32m   1773\u001b[0m     out_flat = batching.batch(\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_mapped_axis_size\u001b[0;34m(fn, tree, vals, dims, name)\u001b[0m\n\u001b[1;32m   1801\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1802\u001b[0;31m   sizes = core.dedup_referents(_get_axis_size(name, np.shape(x), d)\n\u001b[0m\u001b[1;32m   1803\u001b[0m                                for x, d in zip(vals, dims) if d is not None)\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/core.py\u001b[0m in \u001b[0;36mdedup_referents\u001b[0;34m(itr)\u001b[0m\n\u001b[1;32m   1229\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdedup_referents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitr\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1230\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mHashableWrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_referent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mitr\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/core.py\u001b[0m in \u001b[0;36m<dictcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   1229\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mdedup_referents\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mitr\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mIterable\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mList\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mAny\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1230\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mHashableWrapper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mget_referent\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mx\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mitr\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1231\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m   1801\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1802\u001b[0;31m   sizes = core.dedup_referents(_get_axis_size(name, np.shape(x), d)\n\u001b[0m\u001b[1;32m   1803\u001b[0m                                for x, d in zip(vals, dims) if d is not None)\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/api.py\u001b[0m in \u001b[0;36m_get_axis_size\u001b[0;34m(name, shape, axis)\u001b[0m\n\u001b[1;32m   1796\u001b[0m       \u001b[0;31m# TODO(mattjj): better error message here\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1797\u001b[0;31m       raise ValueError(\n\u001b[0m\u001b[1;32m   1798\u001b[0m           \u001b[0;34mf\"{name} was requested to map its argument along axis {axis}, \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;31mUnfilteredStackTrace\u001b[0m: ValueError: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())\n\nThe stack trace below excludes JAX-internal frames.\nThe preceding is the original exception that occurred, unmodified.\n\n--------------------",
            "\nThe above exception was the direct cause of the following exception:\n",
            "\u001b[0;31mValueError\u001b[0m                                Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-69-06fa7acb6de1>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvmap\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrotate_point\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0min_axes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpoints_t\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpi\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;31mValueError\u001b[0m: vmap was requested to map its argument along axis 0, which implies that its rank should be at least 1, but is only 0 (its shape is ())"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(points)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "bF6B-TE90r27",
        "outputId": "3a5c0d81-abac-40c6-f3a7-f1c1d6582843"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.0, 0.0, 3.0, 0.0, 0.0, 4.0]"
            ]
          },
          "metadata": {},
          "execution_count": 70
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(points_t)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X33eR7sw1Kv0",
        "outputId": "59d8bd8c-486d-46a8-d4e4-d585e3e87a8f"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[0.0, 3.0, 0.0, 0.0, 0.0, 4.0]"
            ]
          },
          "metadata": {},
          "execution_count": 71
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# tree_map calls the function on leaves, and by default the leaves of points_t\n",
        "# are the individual floats inside the x / y lists. Treat each list as a leaf\n",
        "# instead, so every coordinate list becomes a single jnp array.\n",
        "points_t_a = jax.tree_util.tree_map(\n",
        "    jnp.array, points_t, is_leaf=lambda node: isinstance(node, list))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 242
        },
        "id": "ys0iVfc91MLp",
        "outputId": "b6968503-3924-4733-fa5b-7b2609e6d5f6"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "error",
          "ename": "AttributeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mAttributeError\u001b[0m                            Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-72-e5c7e91eded8>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpoints_t_a\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mPoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mpoints_t\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf, *rest)\u001b[0m\n\u001b[1;32m    207\u001b[0m   \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    208\u001b[0m   \u001b[0mall_leaves\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mleaves\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten_up_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrest\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 209\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mall_leaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    210\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbuild_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mPyTreeDef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/tree_util.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    207\u001b[0m   \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    208\u001b[0m   \u001b[0mall_leaves\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mleaves\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten_up_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrest\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 209\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mall_leaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    210\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbuild_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mPyTreeDef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-72-e5c7e91eded8>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(p)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mpoints_t_a\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_util\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mPoint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0marray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0my\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mpoints_t\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;31mAttributeError\u001b[0m: 'float' object has no attribute 'x'"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Turn each coordinate list of the transposed Point into a jnp array.\n",
        "points_t_array = Point(*(jnp.array(coords) for coords in points_t))"
      ],
      "metadata": {
        "id": "lwJnGVNl2Im4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "points_t_array"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "iFuafV_G2bwq",
        "outputId": "3f9807f4-a54a-47d2-fcfc-4fec826ba753"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Point(x=Array([0., 3., 0.], dtype=float32), y=Array([0., 0., 4.], dtype=float32))"
            ]
          },
          "metadata": {},
          "execution_count": 74
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# in_axes=(0, None): map over axis 0 of the arrays inside points_t_array and\n",
        "# pass math.pi unchanged to every call.\n",
        "jax.vmap(rotate_point, in_axes=(0, None))(points_t_array, math.pi)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Nw_52_9G2cnR",
        "outputId": "8bdbe2ff-809f-4fb3-a67e-344523fc2990"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "Point(x=Array([-0.0000000e+00, -3.0000000e+00, -4.8985874e-16], dtype=float32), y=Array([ 0.0000000e+00,  3.6739406e-16, -4.0000000e+00], dtype=float32))"
            ]
          },
          "metadata": {},
          "execution_count": 75
        }
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "## Custom nodes"
      ],
      "metadata": {
        "id": "hq3xbiVAPYrx"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "class Layer:\n",
        "  \"\"\"Simple container with a name, weights `w` and bias `b`.\"\"\"\n",
        "\n",
        "  def __init__(self, name, w, b):\n",
        "    # Keep the caller-supplied name (not a fixed literal) so it survives\n",
        "    # a flatten / unflatten round trip as aux data.\n",
        "    self.name = name\n",
        "    self.w = w\n",
        "    self.b = b\n"
      ],
      "metadata": {
        "id": "xG2GU6ZYPWnk"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Example layer: zero-initialised (100, 20) weights and (20,) bias.\n",
        "h1 = Layer(name='hidden1', w=jnp.zeros((100, 20)), b=jnp.zeros((20,)))"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "IFzO_EEhP0IG",
        "outputId": "3100ec45-ea75-4ea1-efb0-6385e6ceeb1a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# A small pytree: a list holding one array and one Layer instance.\n",
        "pt = [jnp.ones(50), h1]"
      ],
      "metadata": {
        "id": "MpbWkfCPQDYB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(pt)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "qMIIxW18QM4t",
        "outputId": "b6f523f8-6d47-4963-99b0-6107c463c306"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
              "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
              "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32),\n",
              " <__main__.Layer at 0x7fc87db04490>]"
            ]
          },
          "metadata": {},
          "execution_count": 5
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Expected to fail here: Layer is not registered as a pytree node yet, so the\n",
        "# whole Layer object is handed to the lambda as a single leaf.\n",
        "jax.tree_util.tree_map(lambda x: x*10, pt)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 242
        },
        "id": "sYoRTuV4QSks",
        "outputId": "ecc5e2d4-fc8d-4efb-dc1c-f7c70015cb6a"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "error",
          "ename": "TypeError",
          "evalue": "ignored",
          "traceback": [
            "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m                                 Traceback (most recent call last)",
            "\u001b[0;32m<ipython-input-6-648c933749c2>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/tree_util.py\u001b[0m in \u001b[0;36mtree_map\u001b[0;34m(f, tree, is_leaf, *rest)\u001b[0m\n\u001b[1;32m    207\u001b[0m   \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    208\u001b[0m   \u001b[0mall_leaves\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mleaves\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten_up_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrest\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 209\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mall_leaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    210\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbuild_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mPyTreeDef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m/usr/local/lib/python3.9/dist-packages/jax/_src/tree_util.py\u001b[0m in \u001b[0;36m<genexpr>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m    207\u001b[0m   \u001b[0mleaves\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtreedef\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtree_flatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtree\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_leaf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    208\u001b[0m   \u001b[0mall_leaves\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mleaves\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflatten_up_to\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mr\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mr\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrest\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 209\u001b[0;31m   \u001b[0;32mreturn\u001b[0m \u001b[0mtreedef\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munflatten\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mxs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mzip\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0mall_leaves\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    210\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    211\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mbuild_tree\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtreedef\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mPyTreeDef\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxs\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mAny\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
            "\u001b[0;32m<ipython-input-6-648c933749c2>\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(x)\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mjax\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtree_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0;36m10\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mpt\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
            "\u001b[0;31mTypeError\u001b[0m: unsupported operand type(s) for *: 'Layer' and 'int'"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "def flatten_layer(container):\n",
        "  \"\"\"Split a Layer into pytree children and aux data.\n",
        "\n",
        "  Returns `([w, b], name)`: `w` and `b` are the children, `name` is the aux data.\n",
        "  \"\"\"\n",
        "  return [container.w, container.b], container.name\n",
        "\n",
        "def unflatten_layer(aux_data, flat_contents):\n",
        "  \"\"\"Rebuild a Layer from the aux data (its name) and the children (w, b).\"\"\"\n",
        "  children = tuple(flat_contents)\n",
        "  return Layer(aux_data, *children)"
      ],
      "metadata": {
        "id": "c9Mc4JnUQZP-"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Tell JAX how to take a Layer apart and how to put it back together.\n",
        "jax.tree_util.register_pytree_node(Layer, flatten_layer, unflatten_layer)"
      ],
      "metadata": {
        "id": "ME__HdtxQ1_k"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(pt)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ZYKD_4yZRCXS",
        "outputId": "6c68c41a-24d0-487b-81fe-e3ebefa35d5b"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
              "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
              "        1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],      dtype=float32),\n",
              " Array([[0., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 0., ..., 0., 0., 0.],\n",
              "        ...,\n",
              "        [0., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 0., ..., 0., 0., 0.],\n",
              "        [0., 0., 0., ..., 0., 0., 0.]], dtype=float32),\n",
              " Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n",
              "        0., 0., 0.], dtype=float32)]"
            ]
          },
          "metadata": {},
          "execution_count": 9
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Layer is a registered pytree node at this point, so the map reaches its w and b.\n",
        "pt2 = jax.tree_util.tree_map(lambda x: x+1, pt)"
      ],
      "metadata": {
        "id": "5d4u6kjSREbR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "jax.tree_util.tree_leaves(pt2)"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "0PMICLdBR4zP",
        "outputId": "7e8ca5c9-72cf-48dc-d190-162185f0dff2"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "[Array([2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,\n",
              "        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.,\n",
              "        2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2., 2.],      dtype=float32),\n",
              " Array([[1., 1., 1., ..., 1., 1., 1.],\n",
              "        [1., 1., 1., ..., 1., 1., 1.],\n",
              "        [1., 1., 1., ..., 1., 1., 1.],\n",
              "        ...,\n",
              "        [1., 1., 1., ..., 1., 1., 1.],\n",
              "        [1., 1., 1., ..., 1., 1., 1.],\n",
              "        [1., 1., 1., ..., 1., 1., 1.]], dtype=float32),\n",
              " Array([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,\n",
              "        1., 1., 1.], dtype=float32)]"
            ]
          },
          "metadata": {},
          "execution_count": 14
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "NpphfhrmR9ZZ"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}